{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training on multiple GPUs with `gluon`\n",
    "\n",
    "Gluon makes it easy to implement data parallel training.\n",
    "In this notebook, we'll implement data parallel training for a convolutional neural network.\n",
    "If you'd like a finer-grained view of the concepts, \n",
    "you might want to first read the previous notebook,\n",
    "[multi-GPU training from scratch](./multiple-gpus-scratch.ipynb).\n",
    "\n",
    "To get started, let's first define a simple convolutional neural network and loss function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import mxnet as mx\n",
    "from mxnet import nd, gluon, autograd\n",
    "net = gluon.nn.Sequential(prefix='cnn_')\n",
    "with net.name_scope():\n",
    "    net.add(gluon.nn.Conv2D(channels=20, kernel_size=3, activation='relu'))\n",
    "    net.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)))\n",
    "    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))\n",
    "    net.add(gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)))\n",
    "    net.add(gluon.nn.Flatten())\n",
    "    net.add(gluon.nn.Dense(128, activation=\"relu\"))\n",
    "    net.add(gluon.nn.Dense(10))\n",
    "    \n",
    "loss = gluon.loss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Initialize on multiple devices\n",
    "\n",
    "Gluon supports initialization of network parameters over multiple devices. We accomplish this by passing in a list of device contexts, instead of the single context we've used in earlier notebooks.\n",
    "When we pass in a list of contexts, the parameters are initialized \n",
    "to be identical across all of our devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "GPU_COUNT = 2 # increase if you have more\n",
    "ctx = [mx.gpu(i) for i in range(GPU_COUNT)]\n",
    "net.collect_params().initialize(ctx=ctx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given a batch of input data,\n",
    "we can split it into as many parts as there are contexts \n",
    "by calling `gluon.utils.split_and_load(batch, ctx)`.\n",
    "The `split_and_load` function doesn't just split the data;\n",
    "it also loads each part onto the appropriate device context. \n",
    "\n",
    "So now when we call the forward pass on two separate parts,\n",
    "each one is computed on its corresponding device, using the copy of the parameters stored there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[[-0.01876061 -0.02165037 -0.01293943  0.03837404 -0.00821797 -0.00911531\n",
      "   0.00416799 -0.00729158 -0.00232711 -0.00155549]\n",
      " [ 0.00441474 -0.01953595 -0.00128483  0.02768224  0.01389615 -0.01320441\n",
      "  -0.01166505 -0.00637776  0.0135425  -0.00611765]]\n",
      "<NDArray 2x10 @gpu(0)>\n",
      "\n",
      "[[ -6.78736670e-03  -8.86893831e-03  -1.04004676e-02   1.72976423e-02\n",
      "    2.26115398e-02  -6.36630831e-03  -1.54974898e-02  -1.22633884e-02\n",
      "    1.19591374e-02  -6.60043515e-05]\n",
      " [ -1.17358668e-02  -2.16879714e-02   1.71219767e-03   2.49827504e-02\n",
      "    1.16810966e-02  -9.52543691e-03  -1.03610428e-02   5.08510228e-03\n",
      "    7.06662657e-03  -9.25292261e-03]]\n",
      "<NDArray 2x10 @gpu(1)>\n"
     ]
    }
   ],
   "source": [
    "from mxnet.test_utils import get_mnist\n",
    "mnist = get_mnist()\n",
    "batch = mnist['train_data'][0:GPU_COUNT*2, :]\n",
    "data = gluon.utils.split_and_load(batch, ctx)\n",
    "print(net(data[0]))\n",
    "print(net(data[1]))"
   ]
  },
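  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the splitting step concrete, here is a minimal pure-Python sketch of the slicing behind an even split. The helper `split_evenly` is hypothetical and exists only for illustration: it slices a plain list, whereas `split_and_load` works on NDArrays and also copies each slice to its device.\n",
    "\n",
    "```python\n",
    "def split_evenly(batch, num_parts):\n",
    "    # Hypothetical helper (not part of gluon): cut `batch` into\n",
    "    # `num_parts` contiguous slices of equal length.\n",
    "    if len(batch) % num_parts != 0:\n",
    "        raise ValueError('batch size must be divisible by the number of parts')\n",
    "    step = len(batch) // num_parts\n",
    "    return [batch[i * step:(i + 1) * step] for i in range(num_parts)]\n",
    "\n",
    "# four examples shared between two devices: two examples each\n",
    "parts = split_evenly([10, 11, 12, 13], 2)\n",
    "print(parts)  # [[10, 11], [12, 13]]\n",
    "```\n",
    "\n",
    "By default `split_and_load` likewise expects the batch size to be divisible by the number of contexts, which is why the batch above holds `GPU_COUNT*2` examples."
   ]
  },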
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At any time, we can access the version of the parameters stored on each device. \n",
    "Recall from the first chapter that our weights may not actually be initialized\n",
    "when we call `initialize` because the parameter shapes may not yet be known. \n",
    "In these cases, initialization is deferred pending shape inference. \n",
    "Because we have already run a forward pass above, the shapes have been inferred and the parameters now hold concrete values on every device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== channel 0 of the first conv on gpu(0) ===\n",
      "[[[ 0.04118239  0.05352169 -0.04762455]\n",
      "  [ 0.06035256 -0.01528978  0.04946674]\n",
      "  [ 0.06110793 -0.00081179  0.02191102]]]\n",
      "<NDArray 1x3x3 @gpu(0)>\n",
      "=== channel 0 of the first conv on gpu(1) ===\n",
      "[[[ 0.04118239  0.05352169 -0.04762455]\n",
      "  [ 0.06035256 -0.01528978  0.04946674]\n",
      "  [ 0.06110793 -0.00081179  0.02191102]]]\n",
      "<NDArray 1x3x3 @gpu(1)>\n"
     ]
    }
    ],
   "source": [
    "# look up the first convolution's weight by its prefixed name\n",
    "weight = net.collect_params()['cnn_conv0_weight']\n",
    "\n",
    "for c in ctx:\n",
    "    print('=== channel 0 of the first conv on {} ==={}'.format(\n",
    "        c, weight.data(ctx=c)[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similarly, we can access the gradients on each of the GPUs. Because each GPU gets a different part of the batch (a different subset of examples), the gradients on each GPU vary. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=== grad of channel 0 of the first conv2d on gpu(0) ===\n",
      "[[[-0.02078936 -0.00562428  0.01711007]\n",
      "  [ 0.01138539  0.0280002   0.04094725]\n",
      "  [ 0.00993335  0.01218192  0.02122578]]]\n",
      "<NDArray 1x3x3 @gpu(0)>\n",
      "=== grad of channel 0 of the first conv2d on gpu(1) ===\n",
      "[[[-0.02543036 -0.02789939 -0.00302115]\n",
      "  [-0.04816786 -0.03347274 -0.00403483]\n",
      "  [-0.03178394 -0.01254033  0.00855637]]]\n",
      "<NDArray 1x3x3 @gpu(1)>\n"
     ]
    }
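    ],
   "source": [
    "def forward_backward(net, data, label):\n",
    "    # run the forward pass on each device, recording it for autograd\n",
    "    with autograd.record():\n",
    "        losses = [loss(net(X), Y) for X, Y in zip(data, label)]\n",
    "    # backpropagate each part's loss on its own device\n",
    "    for l in losses:\n",
    "        l.backward()\n",
    "\n",
    "# take the labels that match the GPU_COUNT*2 examples in `batch`\n",
    "label = gluon.utils.split_and_load(mnist['train_label'][0:GPU_COUNT*2], ctx)\n",
    "forward_backward(net, data, label)\n",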
    "for c in ctx:\n",
    "    print('=== grad of channel 0 of the first conv2d on {} ==={}'.format(\n",
    "        c, weight.grad(ctx=c)[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Putting it all together\n",
    "\n",
    "Now we can implement the remaining functions. Most of them are the same as [when we did everything by hand](./multiple-gpus-scratch.ipynb); one notable difference is that when a `gluon` trainer detects parameters on multiple devices, it automatically aggregates the gradients and synchronizes the parameters. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on [gpu(0)]\n",
      "Batch size is 64\n",
      "Epoch 0, training time = 5.0 sec\n",
      "         validation accuracy = 0.9738\n",
      "Epoch 1, training time = 4.8 sec\n",
      "         validation accuracy = 0.9841\n",
      "Epoch 2, training time = 4.7 sec\n",
      "         validation accuracy = 0.9863\n",
      "Epoch 3, training time = 4.7 sec\n",
      "         validation accuracy = 0.9868\n",
      "Epoch 4, training time = 4.7 sec\n",
      "         validation accuracy = 0.9877\n",
      "Running on [gpu(0), gpu(1)]\n",
      "Batch size is 128\n"
     ]
    }
    ],
   "source": [
    "from mxnet.io import NDArrayIter\n",
    "from time import time\n",
    "\n",
    "def train_batch(batch, ctx, net, trainer):\n",
    "    # split the data batch and load each part onto its GPU\n",
    "    data = gluon.utils.split_and_load(batch.data[0], ctx)\n",
    "    label = gluon.utils.split_and_load(batch.label[0], ctx)\n",
    "    # compute gradients\n",
    "    forward_backward(net, data, label)\n",
    "    # update parameters, normalizing the gradients by the full batch size\n",
    "    trainer.step(batch.data[0].shape[0])\n",
    "\n",
    "def valid_batch(batch, ctx, net):\n",
    "    # validation runs on the first device only\n",
    "    data = batch.data[0].as_in_context(ctx[0])\n",
    "    pred = nd.argmax(net(data), axis=1)\n",
    "    return nd.sum(pred == batch.label[0].as_in_context(ctx[0])).asscalar()\n",
    "\n",
    "def run(num_gpus, batch_size, lr):\n",
    "    # the list of GPUs to use\n",
    "    ctx = [mx.gpu(i) for i in range(num_gpus)]\n",
    "    print('Running on {}'.format(ctx))\n",
    "\n",
    "    # data iterators\n",
    "    mnist = get_mnist()\n",
    "    train_data = NDArrayIter(mnist[\"train_data\"], mnist[\"train_label\"], batch_size)\n",
    "    valid_data = NDArrayIter(mnist[\"test_data\"], mnist[\"test_label\"], batch_size)\n",
    "    print('Batch size is {}'.format(batch_size))\n",
    "\n",
    "    # reinitialize the parameters on the chosen devices\n",
    "    net.collect_params().initialize(force_reinit=True, ctx=ctx)\n",
    "    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})\n",
    "    for epoch in range(5):\n",
    "        # train\n",
    "        start = time()\n",
    "        train_data.reset()\n",
    "        for batch in train_data:\n",
    "            train_batch(batch, ctx, net, trainer)\n",
    "        nd.waitall()  # wait until all computations are finished to benchmark the time\n",
    "        print('Epoch %d, training time = %.1f sec'%(epoch, time()-start))\n",
    "\n",
    "        # validate\n",
    "        valid_data.reset()\n",
    "        correct, num = 0.0, 0.0\n",
    "        for batch in valid_data:\n",
    "            correct += valid_batch(batch, ctx, net)\n",
    "            num += batch.data[0].shape[0]\n",
    "        print('         validation accuracy = %.4f'%(correct/num))\n",
    "\n",
    "run(1, 64, .3)\n",
    "run(GPU_COUNT, 64*GPU_COUNT, .3)"
   ]
  },
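  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, the work that the trainer takes over in `trainer.step` can be sketched in plain Python as follows. This is a simplified illustration under stated assumptions, not gluon's actual implementation: the function `aggregate_and_step` and its arguments are hypothetical, weights and gradients are flat lists of floats, and the update rule is plain SGD.\n",
    "\n",
    "```python\n",
    "def aggregate_and_step(weights, grads_per_device, batch_size, lr):\n",
    "    # Hypothetical sketch (not gluon's code) of one synchronized update.\n",
    "    # 1. aggregate: sum each gradient entry across all devices\n",
    "    total = [sum(entries) for entries in zip(*grads_per_device)]\n",
    "    # 2. update: rescale by 1/batch_size and take one SGD step\n",
    "    updated = [w - lr * g / batch_size for w, g in zip(weights, total)]\n",
    "    # 3. synchronize: every device receives an identical copy\n",
    "    return [list(updated) for _ in grads_per_device]\n",
    "\n",
    "weights = [1.0, 2.0]\n",
    "# one gradient list per device, each computed on a different part of the batch\n",
    "grads = [[0.5, 1.0], [1.5, -1.0]]\n",
    "copies = aggregate_and_step(weights, grads, batch_size=4, lr=0.1)\n",
    "print(copies[0] == copies[1])  # True: the devices stay in sync\n",
    "```\n",
    "\n",
    "Because each device's gradient is the sum over its own examples, summing across devices and dividing by the full batch size gives the average gradient over the whole batch."
   ]
  },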
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Both parameters and trainers in `gluon` support multiple devices. Moving from one device to several is straightforward."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "[Distributed training with multiple machines](../chapter07_distributed-learning/training-with-multiple-machines.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For whinges or inquiries, [open an issue on GitHub.](https://github.com/zackchase/mxnet-the-straight-dope)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.4.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
